The notebook for this lesson uses the JPEG photograph of tulips below as its “dataset.” The notebook reads the pixel values that encode the image and uses them as the data for modeling.
import numpy as np
import pandas as pd
%matplotlib inline
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
# Load the photo as a NumPy array and print its shape
# (rows, columns, color channels)
img = plt.imread('using_kmeans_for_color_compression_tulips_photo.jpg')
print(img.shape)
# Display the original photo with the axis ticks and labels turned off
plt.imshow(img)
plt.axis('off')
(320, 240, 3)
(-0.5, 239.5, 319.5, -0.5)
# Flatten the image so that each row is one pixel and the three columns
# hold that pixel's R, G, B values
n_pixels = img.shape[0] * img.shape[1]
img_flat = img.reshape(n_pixels, 3)
img_flat[:5]
array([[211, 197, 38],
[199, 181, 21],
[178, 154, 0],
[185, 152, 0],
[184, 145, 0]], dtype=uint8)
# Confirm the flattened shape: (number of pixels, 3 color channels)
img_flat.shape
(76800, 3)
# Wrap the pixel array in a DataFrame with one column per channel: r, g, b
channel_names = list('rgb')
img_flat_df = pd.DataFrame(data=img_flat, columns=channel_names)
img_flat_df.head()
|   | r | g | b |
|---|---|---|---|
| 0 | 211 | 197 | 38 |
| 1 | 199 | 181 | 21 |
| 2 | 178 | 154 | 0 |
| 3 | 185 | 152 | 0 |
| 4 | 184 | 145 | 0 |
# Create a 3D plot in RGB space where each pixel in `img` is drawn as a
# marker in its own actual color
pixel_colors = ['rgb({},{},{})'.format(r, g, b)
                for r, g, b in zip(img_flat_df.r.values,
                                   img_flat_df.g.values,
                                   img_flat_df.b.values)]
trace = go.Scatter3d(
    x=img_flat_df.r,
    y=img_flat_df.g,
    z=img_flat_df.b,
    mode='markers',
    marker=dict(size=1, color=pixel_colors, opacity=0.5),
)
data = [trace]
layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))
fig = go.Figure(data=data, layout=layout)
# Label each axis with the color channel it represents
fig.update_layout(scene=dict(xaxis_title='R',
                             yaxis_title='G',
                             zaxis_title='B'))
fig.show()
# Instantiate and fit a K-means model with a single cluster
kmeans = KMeans(n_clusters=1, n_init='auto', random_state=19991209)
kmeans.fit(img_flat)

# Replace every pixel with the centroid of the cluster it was assigned to.
# Indexing cluster_centers_ by labels_ gives one centroid row per pixel;
# astype casts back to the image dtype (truncating any fractional part),
# the same cast that element-wise assignment into a copy of img_flat does.
img_flat1 = kmeans.cluster_centers_[kmeans.labels_].astype(img_flat.dtype)
img1 = img_flat1.reshape(img.shape)
plt.imshow(img1)
plt.axis('off');
The result is the image of our tulips with every pixel replaced by the average color: with k = 1, K-means has a single centroid, and that centroid is the mean of all the pixels. The average color of this photo is brown—all the colors muddled together.
# Mean of each color channel over all pixels (axis 0 runs across pixels).
# Up to floating-point error this should equal the single k=1 centroid.
column_means = np.mean(img_flat, axis=0)
print('column means: ', column_means)
column means: [125.60802083 78.90632813 43.45473958]
# Re-plot every pixel in RGB space (each marker in the pixel's own color)
# and overlay the single k=1 centroid as a larger marker
pixel_colors = ['rgb({},{},{})'.format(r, g, b)
                for r, g, b in zip(img_flat_df.r.values,
                                   img_flat_df.g.values,
                                   img_flat_df.b.values)]
trace = go.Scatter3d(x=img_flat_df.r,
                     y=img_flat_df.g,
                     z=img_flat_df.b,
                     mode='markers',
                     marker=dict(size=1,
                                 color=pixel_colors,
                                 opacity=0.5))
data = [trace]
layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))
fig = go.Figure(data=data, layout=layout)

# Add centroid to chart, colored with its own RGB value. The color string
# is derived from the fitted centroid rather than typed in by hand, so the
# marker color always agrees with the model for this image and this fit.
centroid = kmeans.cluster_centers_[0].tolist()
centroid_color = 'rgb({},{},{})'.format(*centroid)
fig.add_trace(
    go.Scatter3d(x=[centroid[0]],
                 y=[centroid[1]],
                 z=[centroid[2]],
                 mode='markers',
                 marker=dict(size=7,
                             color=[centroid_color],
                             opacity=1))
)
fig.update_layout(scene=dict(xaxis_title='R',
                             yaxis_title='G',
                             zaxis_title='B'))
fig.show()
# Fit a K-means model with three clusters on the same pixel data
kmeans3 = KMeans(n_clusters=3, n_init='auto', random_state=19991209)
kmeans3.fit(img_flat)
# Check the unique values of what's returned by the .labels_ attribute
np.unique(kmeans3.labels_)
array([0, 1, 2])
# Assign centroid coordinates to `centers` variable
# (one row per cluster; the columns are that centroid's R, G, B values)
centers = kmeans3.cluster_centers_
centers
array([[ 41.11904835, 50.27093234, 15.9247325 ],
[202.68983875, 173.15223957, 109.8380343 ],
[176.32140539, 42.10443038, 27.27284161]])
def show_swatch(RGB_value):
    '''
    Draw a one-pixel color swatch for a single RGB value on a new figure.

    Args:
        RGB_value: sequence of exactly three numbers (R, G, B). Each value
            is cast to uint8, so fractional parts are dropped (values are
            assumed to lie in 0-255).
    Returns:
        None. The swatch is drawn on a freshly created matplotlib figure
        with the axis turned off.
    '''
    red, green, blue = RGB_value
    # imshow treats a (1, 1, 3) uint8 array as a single RGB pixel
    swatch = np.array([[[red, green, blue]]]).astype('uint8')
    plt.figure()
    plt.imshow(swatch)
    plt.axis('off')
# Display one color swatch per k=3 centroid
for center_rgb in centers:
    show_swatch(center_rgb)
def cluster_image(k, img=img, random_state=42):
    '''
    Fits a K-means model to a photograph.
    Replaces photo's pixels with RGB values of model's centroids.
    Displays the updated image.

    Args:
        k: (int) - Your selected K-value
        img: (numpy array) - Your original image converted to a numpy
            array; reshaped to (height * width, 3), so it must have exactly
            three channels
        random_state: (int) - Seed passed to KMeans. Defaults to 42 so
            existing calls reproduce the same result.

    Returns:
        A tuple (image_artist, axis_limits): the value returned by
        plt.imshow(new_img) followed by the value returned by
        plt.axis('off'). new_img is the original image with every pixel
        replaced by the coordinates of its assigned (nearest) centroid,
        cast back to the image's dtype.
    '''
    img_flat = img.reshape(img.shape[0] * img.shape[1], 3)
    kmeans = KMeans(n_clusters=k, n_init='auto',
                    random_state=random_state).fit(img_flat)
    # Look up every pixel's centroid in one vectorized step. astype applies
    # the same truncating cast that element-wise assignment into a copy of
    # img_flat would perform.
    new_img = kmeans.cluster_centers_[kmeans.labels_].astype(img_flat.dtype)
    new_img = new_img.reshape(img.shape)
    return plt.imshow(new_img), plt.axis('off')
# Show the photo recolored with k=3. cluster_image seeds KMeans with
# random_state=42 rather than the 19991209 used for kmeans3, so its
# clusters are not guaranteed to match kmeans3's. The trailing ';' keeps
# the notebook from echoing the returned tuple.
cluster_image(3);
# Inspect the fitted k=3 model, one item per line: the shape of the label
# array, the labels themselves, the distinct label values, and the centroids
print(kmeans3.labels_.shape,
      kmeans3.labels_,
      np.unique(kmeans3.labels_),
      kmeans3.cluster_centers_,
      sep='\n')
(76800,)
[1 1 1 ... 2 2 2]
[0 1 2]
[[ 41.11904835 50.27093234 15.9247325 ]
 [202.68983875 173.15223957 109.8380343 ]
 [176.32140539 42.10443038 27.27284161]]
# Add a 'cluster' column holding the cluster number of each row (pixel),
# as assigned by K-means for k=3
img_flat_df = img_flat_df.assign(cluster=kmeans3.labels_)
img_flat_df.head()
|   | r | g | b | cluster |
|---|---|---|---|---|
| 0 | 211 | 197 | 38 | 1 |
| 1 | 199 | 181 | 21 | 1 |
| 2 | 178 | 154 | 0 | 1 |
| 3 | 185 | 152 | 0 | 1 |
| 4 | 184 | 145 | 0 | 2 |
# Create helper dictionary mapping each cluster label to a Plotly color
# string for its centroid, e.g. 'rgb(41.119..., 50.270..., 15.924...)'.
# .tolist() turns the centroid into plain Python floats first: under
# NumPy >= 2, str() of a tuple of NumPy scalars renders 'np.float64(...)',
# which is not a valid Plotly color string.
series_conversion = {
    label: 'rgb' + str(tuple(center.tolist()))
    for label, center in enumerate(kmeans3.cluster_centers_)
}
series_conversion
{0: 'rgb(41.119048349015216, 50.27093233589929, 15.924732501454713)',
1: 'rgb(202.68983875095677, 173.15223957000086, 109.83803429741378)',
2: 'rgb(176.32140538786075, 42.10443037974727, 27.27284160986635)'}
# Swap each cluster number in the 'cluster' column for its formatted RGB
# color string, ready for plotting
rgb_strings = img_flat_df['cluster'].map(series_conversion)
img_flat_df = img_flat_df.assign(cluster=rgb_strings)
img_flat_df.head()
|   | r | g | b | cluster |
|---|---|---|---|---|
| 0 | 211 | 197 | 38 | rgb(202.68983875095677, 173.15223957000086, 10... |
| 1 | 199 | 181 | 21 | rgb(202.68983875095677, 173.15223957000086, 10... |
| 2 | 178 | 154 | 0 | rgb(202.68983875095677, 173.15223957000086, 10... |
| 3 | 185 | 152 | 0 | rgb(202.68983875095677, 173.15223957000086, 10... |
| 4 | 184 | 145 | 0 | rgb(176.32140538786075, 42.10443037974727, 27.... |
# 3-D scatter of every pixel in RGB space, each marker drawn in the centroid
# color of the k=3 cluster it belongs to
trace = go.Scatter3d(x=img_flat_df.r,
                     y=img_flat_df.g,
                     z=img_flat_df.b,
                     mode='markers',
                     marker=dict(size=1,
                                 color=img_flat_df.cluster,
                                 opacity=1))
# Pass the traces as a list, like the other figures in this notebook
data = [trace]
layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))
fig = go.Figure(data=data, layout=layout)
# Label the axes with their color channels, matching the earlier RGB plots
fig.update_layout(scene=dict(xaxis_title='R',
                             yaxis_title='G',
                             zaxis_title='B'))
fig.show()
def cluster_image_grid(k, ax, img=img, random_state=42):
    '''
    Fits a K-means model to a photograph.
    Replaces photo's pixels with RGB values of model's centroids.
    Displays the updated image on an axis of a figure.

    Args:
        k: (int) - Your selected K-value
        ax: (matplotlib Axes) - The subplot axis to draw the image on,
            e.g. one element of the array returned by plt.subplots
        img: (numpy array) - Your original image converted to a numpy
            array; reshaped to (height * width, 3), so it must have exactly
            three channels
        random_state: (int) - Seed passed to KMeans. Defaults to 42 so
            existing calls reproduce the same result.

    Returns:
        None. The recolored image (each pixel replaced by the coordinates
        of its assigned centroid, cast back to the image's dtype) is drawn
        onto `ax`, and that axis's lines and labels are turned off.
    '''
    img_flat = img.reshape(img.shape[0] * img.shape[1], 3)
    kmeans = KMeans(n_clusters=k, n_init='auto',
                    random_state=random_state).fit(img_flat)
    # Vectorized centroid lookup; astype applies the same truncating cast
    # as element-wise assignment into a copy of img_flat.
    new_img = kmeans.cluster_centers_[kmeans.labels_].astype(img_flat.dtype)
    new_img = new_img.reshape(img.shape)
    ax.imshow(new_img)
    ax.axis('off')
# Create a 3 x 3 canvas, 9 x 12 inches. plt.subplots already returns the
# figure, so no separate gcf() lookup is needed (and only matplotlib.pyplot
# is imported, as plt).
fig, axs = plt.subplots(3, 3, figsize=(9, 12))
axs = axs.flatten()
# One subplot per K value from 2 through 10
k_values = np.arange(2, 11)
for ax, k in zip(axs, k_values):
    cluster_image_grid(k, ax, img=img)
    ax.title.set_text('k=' + str(k))